"""
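Wald tests for joint significance of demographic (dZ) coefficients
in gravity models 2b, 2c, 2d, 2e, and 2f-ii.

Since we lack the full variance-covariance matrix of the coefficient
estimates, we approximate the Wald statistic by chi2 = sum(t_i^2) with k
degrees of freedom, i.e. we treat the estimates as uncorrelated. This is
exact only for a diagonal covariance matrix and is NOT conservative in
general. For two estimates with correlation rho and t-statistics (t, t), the
true Wald statistic is 2t^2/(1 + rho). For (t, -t) it is 2t^2/(1 - rho).
So sum(t_i^2) can either overstate or understate the true statistic.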
"""

import pandas as pd
import numpy as np
from scipy import stats

# ── Load data ──
# Regression results table produced upstream. This script reads its
# "model", "variable", "coefficient", "std_error" and "t_stat" columns.
results_path = (
    "/mnt/c/demographics_capital_flows/gravity_bilateral/output/tables/"
    "gravity_results.csv"
)
df = pd.read_csv(filepath_or_buffer=results_path)

# ── Define which models and which variable sets to test ──
# Each test is a dict with:
#   "model"      - model label, matched against df["model"]
#   "test_label" - null hypothesis shown in the output
#   "variables"  - regressor names whose coefficients are jointly restricted to 0

# Demographic regressors and their KAOPEN interactions
dz_vars = ["dZ_1", "dZ_2", "dZ_3"]
dz_kaopen_vars = ["dZ_1_x_kaopen_j", "dZ_2_x_kaopen_j", "dZ_3_x_kaopen_j"]

# Models in which only the dZ block is tested:
#   2b (total bilateral positions), 2d (portfolio equity / portfolio debt /
#   FDI outward), 2e (price controls), 2f-ii (fitted & actual rates).
models_basic = [
    "2b: Gravity + Demographics",
    "2d: Portfolio Equity",
    "2d: Portfolio Debt",
    "2d: FDI Outward",
    "2e: Gravity + Demographics + Price Controls",
    "2f-ii: Full Model + Fitted & Actual Rates",
]

tests = [
    {"model": model, "test_label": "H0: all dZ = 0", "variables": dz_vars}
    for model in models_basic
]

# 2c adds dZ x KAOPEN interactions. Test dZ alone, the interactions alone,
# and all six coefficients jointly.
_kaopen_model = "2c: Gravity + Demographics + KAOPEN interactions"
tests.extend(
    {"model": _kaopen_model, "test_label": label, "variables": variables}
    for label, variables in (
        ("H0: all dZ = 0", dz_vars),
        ("H0: all dZ x KAOPEN = 0", dz_kaopen_vars),
        ("H0: all dZ & dZ x KAOPEN = 0", dz_vars + dz_kaopen_vars),
    )
)

# ── Run tests ──
# Rough number of estimated parameters, subtracted from N to obtain the
# F-test denominator degrees of freedom.
# NOTE(review): one placeholder for every model rather than each model's
# actual parameter count. Confirm before relying on the F p-values.
n_params_approx = 8

output_rows = []

for test in tests:
    model_name = test["model"]
    label = test["test_label"]
    var_list = test["variables"]

    # Number of restrictions
    k = len(var_list)
    if k == 0:
        print(f"WARNING: No variables to test for {model_name} ({label})")
        continue

    model_rows = df[df["model"] == model_name]
    sub = model_rows[model_rows["variable"].isin(var_list)]

    # Every restricted variable must appear exactly once. Comparing only the
    # row count would let a duplicated variable hide a missing one.
    counts = sub["variable"].value_counts()
    missing = [v for v in var_list if v not in counts.index]
    duplicated = [v for v in var_list if counts.get(v, 0) > 1]
    if missing or duplicated or len(sub) != k:
        print(
            f"WARNING: Expected {k} variables for {model_name}, "
            f"found {len(sub)} (missing: {missing}, duplicated: {duplicated})"
        )
        continue

    # The observation count is stored as a pseudo-variable row "_N_obs";
    # a missing row or a NaN value leaves n_obs as NaN instead of crashing.
    n_obs_row = model_rows[model_rows["variable"] == "_N_obs"]
    n_obs = np.nan
    if len(n_obs_row) > 0 and pd.notna(n_obs_row["coefficient"].iloc[0]):
        n_obs = int(n_obs_row["coefficient"].iloc[0])

    t_stats = sub["t_stat"].to_numpy(dtype=float)

    # Wald statistic with a diagonal covariance approximation. The
    # covariances between coefficient estimates are unavailable and are
    # ignored, so this can over- or understate the true Wald statistic.
    chi2_stat = np.sum(t_stats ** 2)
    # sf() rather than 1 - cdf(): 1 - cdf() rounds tiny p-values to exactly 0.
    chi2_p = stats.chi2.sf(chi2_stat, df=k)

    # F-statistic approximation: F = chi2 / k
    f_stat = chi2_stat / k
    denom_df = n_obs - n_params_approx if not np.isnan(n_obs) else np.inf
    if np.isinf(denom_df):
        # As the denominator df grows without bound, k * F(k, df) -> chi2(k).
        f_p = chi2_p
    elif denom_df > 0:
        f_p = stats.f.sf(f_stat, dfn=k, dfd=denom_df)
    else:
        # Too few observations for the assumed number of parameters.
        f_p = np.nan

    output_rows.append({
        "model": model_name,
        "test": label,
        "k_restrictions": k,
        "chi2_stat": round(chi2_stat, 3),
        "chi2_p_value": chi2_p,
        "F_stat": round(f_stat, 3),
        "F_p_value": f_p,
        "N_obs": int(n_obs) if not np.isnan(n_obs) else "",
    })

# ── Save ──
# One row per executed test. Tests skipped with a warning are not written.
out_path = (
    "/mnt/c/demographics_capital_flows/gravity_bilateral/output/tables/"
    "wald_tests.csv"
)
out_df = pd.DataFrame(output_rows)
out_df.to_csv(path_or_buf=out_path, index=False)

# ── Print summary ──
def _significance_stars(p_value):
    """Return '***', '**', '*' or '' for p < 0.001, 0.01, 0.05 (NaN -> '')."""
    if p_value < 0.001:
        return "***"
    if p_value < 0.01:
        return "**"
    if p_value < 0.05:
        return "*"
    return ""


def _f_denominator_label(n_obs):
    """Return the F-test denominator df as text, or "inf" when N is missing.

    Uses the module-level ``n_params_approx`` from the "Run tests" section, so
    the displayed df matches the one used for the F p-values.
    """
    try:
        return str(int(n_obs) - n_params_approx)
    except (TypeError, ValueError, OverflowError):
        # N_obs is written as "" when the model has no usable _N_obs row.
        return "inf"


print("=" * 100)
print("WALD TESTS FOR JOINT SIGNIFICANCE OF DEMOGRAPHIC COEFFICIENTS")
print("(Approximate: chi2 = sum(t_i^2), ignores covariances between coefficient estimates)")
print("=" * 100)

for _, row in out_df.iterrows():
    k_res = row["k_restrictions"]
    chi2_stars = _significance_stars(row["chi2_p_value"])
    # The F line gets stars from its own p-value, not the chi2 one.
    f_stars = _significance_stars(row["F_p_value"])

    print(f"\nModel: {row['model']}")
    print(f"  Test:  {row['test']}")
    print(f"  k = {k_res} restrictions, N = {row['N_obs']}")
    print(f"  chi2({k_res}) = {row['chi2_stat']:.3f},  p = {row['chi2_p_value']:.2e}  {chi2_stars}")
    print(f"  F({k_res}, {_f_denominator_label(row['N_obs'])}) = {row['F_stat']:.3f},  p = {row['F_p_value']:.2e}  {f_stars}")

print(f"\nResults saved to: {out_path}")
